{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from cot import Collection\n",
    "from cot.generate import FRAGMENTS\n",
    "from rich.pretty import pprint\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot.create import create_thoughtsource_100, create_thoughtsource_1\n",
    "from sprint_utils import model_sprint_thoughtsource"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name_splits = [(\"commonsense_qa\", \"validation\"),\n",
    "                        (\"med_qa\", \"test\"),\n",
    "                        (\"medmc_qa\", \"validation\"),\n",
    "                        (\"open_book_qa\", \"test\"),\n",
    "                        (\"strategy_qa\", \"train\"),\n",
    "                        (\"worldtree\", \"test\")]\n",
    "\n",
    "# api_service_model = [(\"cohere\", \"command-xlarge-nightly\", 1)]\n",
    "# api_service_model = [(\"openai\", \"text-davinci-003\", 1)]\n",
    "# api_service_model = [(\"huggingface_hub\", \"google/flan-t5-xl\", 1)]\n",
    "# api_service_model = [(\"mock_api\", \"mock_model\", 0)]\n",
    "api_service_model = [(\"openai\", \"text-davinci-002\", 1)]\n",
    "cot_trigger_key = [None, \"kojima-01\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# collection = Collection.from_json(\"thoughtsource_data/thoughtsource_1_empty.json\")\n",
    "collection = Collection.from_json(\"thoughtsource_data/thoughtsource_100_empty.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "commonsense_qa\n",
      "Generating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9f0fe7d56a9455aa08fbdb644be83e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "med_qa\n",
      "Generating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c0a3a2c24a64a71b982c6267c92d23e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "medmc_qa\n",
      "Generating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9c4289d908a40cd955e06990f45c9c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "open_book_qa\n",
      "Generating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96ab3591d3c04a55b7db8e8cb33a3225",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "strategy_qa\n",
      "Generating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7e30d008828412a9b81d1b513b01a86",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "worldtree\n",
      "Generating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "180f207966ef492d829a020f8a38e51c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ffd5bf8ca3b44b0be1e4e6893fd98ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d17fa07e16bb4ab6b925e7ac12bc15de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9bdcc865a75541f1807eb8331c577c8f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3e1ef92693648f7a9fbe8982e14626b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bee42f35aafb4969a57a0e64a8510023",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae41e5354d60419085dbcc38500c44a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45f77ab211154a1da5664e9e4b40d5c7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87a101dd1a6f4ce19df8b132ec8cd533",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3710b89ab2a0440399b5507b90de9dca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "beca04bb0b7145fa8b9ad8a8404fd62e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90ebaa26da614b7994f8234ff88cf334",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f31ae1d41e01440981e7152b29b6a674",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_sprint_thoughtsource(collection, dataset_name_splits, api_service_model, cot_trigger_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27cf6b3a1fd047d4b7562de9ea52bb8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83b58a5bc39140f6b8c51698aaa0e05a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "64cbc7b2b1f74a12a03bff4cdd6e090a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a09bd1498b3a49b1972c38e9fe462410",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6843d0cc8769440fbe0ee11f1156fefd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54a6bd2bfdd64d3bae04179e86480923",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'commonsense_qa'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'validation'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'text-davinci-002'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'None_None_kojima-A-E'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.76</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'None_kojima-01_kojima-A-E'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.62</span><span style=\"font-weight: bold\">}}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">}</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'med_qa'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'test'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'text-davinci-002'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'None_None_kojima-A-E'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.41</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'None_kojima-01_kojima-A-E'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span><span style=\"font-weight: bold\">}}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">}</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'medmc_qa'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'validation'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'text-davinci-002'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'None_None_kojima-A-D'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'None_kojima-01_kojima-A-D'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span><span style=\"font-weight: bold\">}}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">}</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'open_book_qa'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'test'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'text-davinci-002'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'None_None_kojima-A-D'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.67</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'None_kojima-01_kojima-A-D'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.57</span><span style=\"font-weight: bold\">}}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">}</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'strategy_qa'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'train'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'text-davinci-002'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'None_None_kojima-yes-no'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'None_kojima-01_kojima-yes-no'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.38</span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">}</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'worldtree'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'test'</span>: <span style=\"font-weight: bold\">{</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   │   </span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'text-davinci-002'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'None_None_kojima-A-D'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.88</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'None_kojima-01_kojima-A-D'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.78</span><span style=\"font-weight: bold\">}}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'commonsense_qa'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[32m'validation'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'text-davinci-002'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'None_None_kojima-A-E'\u001b[0m: \u001b[1;36m0.76\u001b[0m, \u001b[32m'None_kojima-01_kojima-A-E'\u001b[0m: \u001b[1;36m0.62\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[1m}\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'med_qa'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[32m'test'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'text-davinci-002'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'None_None_kojima-A-E'\u001b[0m: \u001b[1;36m0.41\u001b[0m, \u001b[32m'None_kojima-01_kojima-A-E'\u001b[0m: \u001b[1;36m0.34\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[1m}\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'medmc_qa'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[32m'validation'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'text-davinci-002'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'None_None_kojima-A-D'\u001b[0m: \u001b[1;36m0.34\u001b[0m, \u001b[32m'None_kojima-01_kojima-A-D'\u001b[0m: \u001b[1;36m0.34\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[1m}\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'open_book_qa'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[32m'test'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'text-davinci-002'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'None_None_kojima-A-D'\u001b[0m: \u001b[1;36m0.67\u001b[0m, \u001b[32m'None_kojima-01_kojima-A-D'\u001b[0m: \u001b[1;36m0.57\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[1m}\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'strategy_qa'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[32m'train'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   │   \u001b[0m\u001b[32m'text-davinci-002'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'None_None_kojima-yes-no'\u001b[0m: \u001b[1;36m0.36\u001b[0m, \u001b[32m'None_kojima-01_kojima-yes-no'\u001b[0m: \u001b[1;36m0.38\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[1m}\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'worldtree'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[32m'test'\u001b[0m: \u001b[1m{\u001b[0m\n",
       "\u001b[2;32m│   │   │   \u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'text-davinci-002'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'None_None_kojima-A-D'\u001b[0m: \u001b[1;36m0.88\u001b[0m, \u001b[32m'None_kojima-01_kojima-A-D'\u001b[0m: \u001b[1;36m0.78\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   │   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "eval = collection.evaluate()\n",
    "pprint(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot.stats import evaluation_as_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>None</th>\n",
       "      <th>kojima-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-002</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.76</td>\n",
       "      <td>0.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>0.41</td>\n",
       "      <td>0.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>0.34</td>\n",
       "      <td>0.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>0.67</td>\n",
       "      <td>0.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.36</td>\n",
       "      <td>0.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>0.88</td>\n",
       "      <td>0.78</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                           None        kojima-01\n",
       "               text-davinci-002 text-davinci-002\n",
       "commonsense_qa             0.76             0.62\n",
       "med_qa                     0.41             0.34\n",
       "medmc_qa                   0.34             0.34\n",
       "open_book_qa               0.67             0.57\n",
       "strategy_qa                0.36             0.38\n",
       "worldtree                  0.88             0.78"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new = evaluation_as_table(eval)\n",
    "new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_coll = Collection.from_json(\"thoughtsource_data/thoughtsource_100_kojima_lievin_wei.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7815898f5cc14eb5a7b6beb325c4db14",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fe204fa24e24417a9c418b745237d5e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b0c845efa8044b69cda49288736234f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3acffe07c6d34aba99397d239968579c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff498e5f2abf43b5ba0be06ca20ccc9e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2bd3bd1d0214eda934a44f922faedfb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "old = evaluation_as_table(old_coll.evaluate(overwrite=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiIndex([('kojima-01', 'text-davinci-002'),\n",
       "            ('lievin-01', 'text-davinci-002'),\n",
       "            ('lievin-02', 'text-davinci-002'),\n",
       "            ('lievin-03', 'text-davinci-002'),\n",
       "            ('lievin-10', 'text-davinci-002'),\n",
       "            (     'None', 'text-davinci-002')],\n",
       "           )"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "old.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the lievin-* prompt columns so that `old` keeps only the kojima-01 and None columns.\n",
    "# Reassigning (instead of `inplace=True`) together with `errors='ignore'` lets this cell\n",
    "# be re-run without raising a KeyError for labels that were already dropped.\n",
    "lievin_columns = [\n",
    "    ('lievin-01', 'text-davinci-002'),\n",
    "    ('lievin-02', 'text-davinci-002'),\n",
    "    ('lievin-03', 'text-davinci-002'),\n",
    "    ('lievin-10', 'text-davinci-002'),\n",
    "]\n",
    "old = old.drop(columns=lievin_columns, errors='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>kojima-01</th>\n",
       "      <th>None</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-002</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.7</td>\n",
       "      <td>0.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>0.46</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>0.37</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.54</td>\n",
       "      <td>0.59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      kojima-01             None\n",
       "               text-davinci-002 text-davinci-002\n",
       "commonsense_qa              0.7             0.69\n",
       "med_qa                     0.46              NaN\n",
       "medmc_qa                   0.37              NaN\n",
       "open_book_qa                NaN              NaN\n",
       "strategy_qa                0.54             0.59\n",
       "worldtree                   NaN              NaN"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "old"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the columns of `old` into the same order as the columns of `new`;\n",
    "# any label of `new` that `old` lacks becomes an all-NaN column.\n",
    "old = old.reindex(new.columns, axis='columns')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>None</th>\n",
       "      <th>kojima-01</th>\n",
       "      <th>None</th>\n",
       "      <th>kojima-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-002</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.69</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.46</td>\n",
       "      <td>0.41</td>\n",
       "      <td>0.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.59</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.78</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                           None        kojima-01             None  \\\n",
       "               text-davinci-002 text-davinci-002 text-davinci-002   \n",
       "commonsense_qa             0.69              0.7             0.76   \n",
       "med_qa                      NaN             0.46             0.41   \n",
       "medmc_qa                    NaN             0.37             0.34   \n",
       "open_book_qa                NaN              NaN             0.67   \n",
       "strategy_qa                0.59             0.54             0.36   \n",
       "worldtree                   NaN              NaN             0.88   \n",
       "\n",
       "                      kojima-01  \n",
       "               text-davinci-002  \n",
       "commonsense_qa             0.62  \n",
       "med_qa                     0.34  \n",
       "medmc_qa                   0.34  \n",
       "open_book_qa               0.57  \n",
       "strategy_qa                0.38  \n",
       "worldtree                  0.78  "
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "pd.concat([old, new], axis=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "ccbfa654f25866afe66a1c016d0b518a994fbe20a93c2d8b432dbe114020e2d2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
